mod directive;
mod state;

use std::collections::HashSet;

use common::AuthContext;
use grafbase_sdk::{
    AuthorizationExtension, IntoAuthorizeQueryOutput,
    host_io::{
        self,
        http::{HttpRequest, Url},
    },
    types::{
        AuthenticatedRequestContext, AuthorizationDecisions, AuthorizeQueryOutput, AuthorizedOperationContext,
        Configuration, DirectiveSite, Error, ErrorResponse, QueryElements, SubgraphHeaders,
    },
};

use directive::*;
use state::*;

/// Authorization extension handling the `jwtScope` and `accessControl` directives.
///
/// Access-control decisions rely on a list of authorized user ids fetched over HTTP from the
/// auth service configured in [`Config`].
#[derive(AuthorizationExtension)]
struct MyAuthorization {
    /// Endpoint queried (POST) for the ids the current user may access; derived from the
    /// configured `auth_service_url` with `/authorized-users` appended.
    authorized_users_url: Url,
}

/// Extension configuration, deserialized when the extension is constructed.
#[derive(serde::Deserialize)]
struct Config {
    /// Base URL of the auth service; `/authorized-users` is appended to it to form the
    /// endpoint returning the authorized user ids.
    auth_service_url: String,
}

impl AuthorizationExtension for MyAuthorization {
    fn new(config: Configuration) -> Result<Self, Error> {
        let Config { mut auth_service_url } = config.deserialize()?;
        auth_service_url.push_str("/authorized-users");
        let authorized_users_url = auth_service_url
            .parse()
            .map_err(|err| format!("Invalid authorized_users URL: {err}"))?;
        Ok(Self { authorized_users_url })
    }

    fn authorize_query(
        &mut self,
        ctx: &AuthenticatedRequestContext,
        _headers: &SubgraphHeaders,
        elements: QueryElements<'_>,
    ) -> Result<impl IntoAuthorizeQueryOutput, ErrorResponse> {
        // Deserialize the token that has been generated by our authentication extension.
        // We expect it to be present and properly serialized, if not we stop the request
        // processing.
        let Some(common::Token { current_user_id }) = ctx
            .token()
            .as_bytes()
            .and_then(|bytes| postcard::from_bytes(bytes).ok())
        else {
            return Err(ErrorResponse::internal_server_error());
        };

        // Builder will keep track of all of our authorization decisions. There are two simpler
        // variants `AuthorizationDecisions::grant_all()` and `AuthorizationDecisions::deny_all()`
        // for simpler cases.
        let mut builder = AuthorizationDecisions::deny_some_builder();
        let mut lazy_error_id = None;

        // We accumulate all the scopes with the subgraph which will need it.
        let mut required_jwt_scopes_accumulator = HashSet::new();
        // List of authorized user ids we lazily retrieve from the auth-service.
        let mut authorized_user_ids = None;

        let mut state = State::default();

        // Each element represents an object, field, enum, etc. within the query that was decorated
        // with one of our directives.
        for (directive_name, elements) in elements.iter_grouped_by_directive_name() {
            match directive_name {
                "jwtScope" => {
                    for element in elements {
                        let JwtScopeArguments { scopes } = element.directive_arguments::<JwtScopeArguments>()?;
                        required_jwt_scopes_accumulator.extend(scopes.into_iter().map(|scope| {
                            (
                                element
                                    .subgraph_name()
                                    .expect("extension is configured to receive subgraph name"),
                                scope,
                            )
                        }));
                    }
                }
                "accessControl" => {
                    for element in elements {
                        match element.directive_site() {
                            DirectiveSite::Object(object) => match object.name() {
                                "Account" => state.denied_ids.push(DeniedIds {
                                    query_element_id: element.id().into(),
                                    authorized_ids: if let Some(ids) = authorized_user_ids.as_ref() {
                                        ids
                                    } else {
                                        authorized_user_ids = Some(
                                            self.get_authorized_ids(current_user_id)
                                                .map_err(|err| ErrorResponse::unauthorized().with_error(err))?,
                                        );
                                        authorized_user_ids.as_ref().unwrap()
                                    }
                                    .clone(),
                                }),
                                _ => {
                                    return Err(unsupported());
                                }
                            },
                            DirectiveSite::FieldDefinition(field) => {
                                match (field.parent_type_name(), field.name()) {
                                    ("Query", "user") | ("Mutation", "updateUser") => {
                                        let AccessControlArguments { arguments, .. } = element.directive_arguments()?;
                                        let arguments = arguments.unwrap();
                                        let ids = if let Some(ids) = authorized_user_ids.as_ref() {
                                            ids
                                        } else {
                                            authorized_user_ids = Some(self.get_authorized_ids(current_user_id)?);
                                            authorized_user_ids.as_ref().unwrap()
                                        };
                                        if !ids.contains(&arguments.id_as_u32()) {
                                            let error_id = *lazy_error_id.get_or_insert_with(|| {
                                                builder.push_error("Not authorized: cannot access user")
                                            });
                                            // We re-use the same GraphQL error here to avoid sending duplicate data back to
                                            // the gateway. The GraphQL response will have an individual error for each element
                                            // however.
                                            builder.deny_with_error_id(element, error_id);
                                        }
                                    }
                                    _ => {
                                        return Err(unsupported());
                                    }
                                }
                            }
                            _ => unreachable!(),
                        }
                    }
                }
                _ => unreachable!(),
            }
        }

        let mut out_ctx = AuthContext::default();
        for (subgraph_name, scope) in required_jwt_scopes_accumulator {
            out_ctx
                .scopes
                .entry(subgraph_name)
                .and_modify(|value| {
                    value.push('.');
                    value.push_str(scope);
                })
                .or_insert_with(|| scope.to_string());
        }

        // For a simpler alternative that works with serde, we recommend `postcard`. rkyv has the
        // benefit of proving zero-copy deserialization. And authorize-response may be called
        // multiple times contrary to authorize_query which is called only once if necessary.
        let state = rkyv::api::high::to_bytes_in::<_, rkyv::rancor::Error>(&state, Vec::new()).unwrap();

        let out_ctx = postcard::to_allocvec(&out_ctx).unwrap();
        Ok(AuthorizeQueryOutput::new(builder.build()).state(state).context(out_ctx))
    }

    fn authorize_response(
        &mut self,
        _ctx: &AuthorizedOperationContext,
        state: Vec<u8>,
        elements: grafbase_sdk::types::ResponseElements<'_>,
    ) -> Result<AuthorizationDecisions, Error> {
        let state = rkyv::access::<ArchivedState, rkyv::rancor::Error>(&state).unwrap();

        let mut builder = AuthorizationDecisions::deny_some_builder();
        let mut lazy_error_id = None;

        // Each element here matches one of the query elements we received in authorize_query. But
        // we only receive them here if and only if the directive requested something from the
        // response, with a `FieldSet` typically.
        for element in elements {
            if let Some(denied) = state
                .denied_ids
                .iter()
                .find(|denied| denied.query_element_id == u32::from(element.query_element_id()))
            {
                // Each item here represents an item within the GraphQL response. So if the query
                // was something like `query { users { secret } }`. We'll only receive one element
                // for `users` field or `User` type, etc. depending on the directive location. But
                // an item for every occurrence of the user within the response.
                for item in element.items() {
                    let AccessControlArguments { fields, .. } = item.directive_arguments()?;
                    let fields = fields.unwrap();
                    if !denied.authorized_ids.contains(&(fields.id_as_u32().into())) {
                        let error_id = *lazy_error_id
                            .get_or_insert_with(|| builder.push_error("Not authorized: cannot access account"));
                        builder.deny_with_error_id(item, error_id);
                    }
                }
            }
        }

        Ok(builder.build())
    }
}

impl MyAuthorization {
    /// Queries the auth service for the user ids that `current_user_id` may access.
    ///
    /// Sends a JSON body `{"current_user_id": ...}` in a POST request to `authorized_users_url`.
    /// It expects a JSON body with an `authorized_users` array in return.
    ///
    /// # Errors
    ///
    /// Returns an `"Unauthorized"` error when the request fails or the response body cannot be
    /// decoded. The underlying cause is logged.
    fn get_authorized_ids(&self, current_user_id: u32) -> Result<Vec<u32>, Error> {
        #[derive(serde::Serialize)]
        struct Request {
            current_user_id: u32,
        }

        #[derive(serde::Deserialize)]
        struct Response {
            authorized_users: Vec<u32>,
        }

        let endpoint = self.authorized_users_url.clone();
        let request = HttpRequest::post(endpoint).json(&Request { current_user_id });

        let response = match host_io::http::execute(request) {
            Ok(response) => response,
            Err(err) => {
                log::error!("Failed to fetch policies: {err}");
                return Err("Unauthorized".into());
            }
        };

        let decoded: Result<Response, _> = response.json();
        match decoded {
            Ok(body) => Ok(body.authorized_users),
            Err(err) => {
                log::error!("Failed to parse policy response: {err}");
                Err("Unauthorized".into())
            }
        }
    }
}

fn unsupported() -> ErrorResponse {
    ErrorResponse::internal_server_error().with_error(Error::new("Unsupported"))
}
